Slim Learning Note

slim = tf.contrib.slim

Layers

slim.arg_scope assigns default values to selected arguments of the listed functions (see the sketch after the layer list below):

with slim.arg_scope([func1, func2, ...], arg1=val1, arg2=val2, ...):

https://www.tensorflow.org/api_docs/python/tf/contrib/layers

  • slim.conv2d
  • slim.max_pool2d
  • slim.avg_pool2d
  • slim.dropout
  • slim.batch_norm
  • slim.softmax
  • slim.repeat(inputs, repetitions, layer, *args, **kwargs)
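
A hedged sketch of arg_scope combined with the layers above (TF 1.x; the images placeholder and all layer sizes are illustrative, not from the original note):

import tensorflow as tf
slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# every layer inside the scope inherits padding='SAME' unless overridden
with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'):
    net = slim.conv2d(images, 64, [3, 3], scope='conv1')
    # slim.repeat stacks the same layer twice: conv2/conv2_1, conv2/conv2_2
    net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.dropout(net, keep_prob=0.5, is_training=True, scope='drop1')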

From slim's ResNet utilities (resnet_utils.py), a network is described as a list of Block namedtuples, and intermediate outputs are recorded with collect_named_outputs:

class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

Load pretrained model

Note that the initialization function must be applied to the session (sess):

init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, slim.get_variables_to_restore())
init_fn(sess)
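
A usage sketch (the checkpoint path and the excluded scope name are illustrative, not fixed):

# when fine-tuning, e.g. exclude the final logits layer from restoration
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_fn = slim.assign_from_checkpoint_fn('/path/to/model.ckpt', variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # overwrites the fresh initialization with checkpoint values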

Batch normalization

With slim:

batch_norm_params = {
    # Decay for the moving averages.
    'decay': batch_norm_decay,
    # epsilon to prevent 0s in variance.
    'epsilon': batch_norm_epsilon,
    # collection containing update_ops.
    'updates_collections': tf.GraphKeys.UPDATE_OPS,
}
with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                    normalizer_params=batch_norm_params):
    # conv layers defined inside this scope get batch normalization attached
    ...

For tf.contrib.layers or tf.slim, when is_training=True, the mean and variance of each batch are used and moving_mean and moving_variance are updated if applicable. When is_training=False, the loaded moving_mean and moving_variance are used.
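
For example (a sketch; the placeholder-based is_training flag and tensor shapes are assumptions), one graph can serve both phases by feeding the flag at run time:

inputs = tf.placeholder(tf.float32, [None, 28, 28, 32])
is_training = tf.placeholder(tf.bool, name='is_training')

net = slim.batch_norm(inputs, decay=0.999, epsilon=1e-3,
                      is_training=is_training,
                      updates_collections=tf.GraphKeys.UPDATE_OPS)
# feed {is_training: True} while training, {is_training: False} at test time
# so that the loaded moving_mean/moving_variance are used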

To make the update of moving_mean and moving_variance actually run, special attention is needed: the update op is detached from gradient descent, so it must be attached explicitly in one of the following ways.

The first method:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
variables_to_train = _get_variables_to_train()
grads_and_vars = optimizer.compute_gradients(total_loss, variables_to_train)
grad_updates = optimizer.apply_gradients(grads_and_vars)
update_ops.append(grad_updates)
update_op = tf.group(*update_ops)
with tf.control_dependencies([update_op]):
    # running train_op executes both the gradient step and the batch_norm updates
    train_op = tf.identity(total_loss)

The second method:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

The third method:

train_op = slim.learning.create_train_op(total_loss, optimizer)  # collects tf.GraphKeys.UPDATE_OPS internally

Alternatively, one can set updates_collections=None in slim.batch_norm to force the updates to happen in place, but that can carry a speed penalty, especially in distributed settings.
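
For example, adapting the batch_norm_params dict from above (the decay/epsilon values are illustrative):

batch_norm_params = {
    'decay': 0.997,
    'epsilon': 1e-5,
    # None forces moving_mean/moving_variance to be updated in place during the
    # forward pass, so no explicit UPDATE_OPS dependency is needed
    'updates_collections': None,
}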

However, when training on small-scale datasets, using moving_mean and moving_variance in the test stage often leads to extremely poor performance (close to random guessing). This is due to a cold start that leaves moving_mean/moving_variance unstable. There are two ways to fix the cold-start issue:

  • In the testing stage, also set is_training=True, i.e., use the mean and variance of each test batch.

  • Decrease the batch_norm running-average decay from the default 0.999 to something like 0.99, which speeds up the warm-up. When tuning decay, there is a trade-off between warm-up speed and statistical accuracy; for small-scale datasets, warm-up may take an exceedingly long time, e.g., 300 epochs. A sketch of both fixes follows this list.
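
A combined sketch of the two fixes (the decay value and tensor shape are illustrative):

features = tf.placeholder(tf.float32, [None, 8, 8, 64])

# fix 2: decay=0.99 instead of the default 0.999 so the moving statistics warm
# up in fewer steps; fix 1 amounts to keeping is_training=True even when
# feeding test batches, so per-batch statistics are used
net = slim.batch_norm(features, decay=0.99, is_training=True)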

Without slim: tf.nn.batch_normalization, which keeps no moving_mean/moving_variance:

def batchnorm(bn_input):
    with tf.variable_scope("batchnorm"):
        # this block looks like it has 3 inputs on the graph unless we do this
        bn_input = tf.identity(bn_input)

        channels = bn_input.get_shape()[3]
        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
        mean, variance = tf.nn.moments(bn_input, axes=[0, 1, 2], keep_dims=False)
        normalized = tf.nn.batch_normalization(bn_input, mean, variance, offset, scale, variance_epsilon=1e-5)
        return normalized
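
A quick usage sketch of the helper above (the input shape is arbitrary):

x = tf.placeholder(tf.float32, [None, 64, 64, 32])
y = batchnorm(x)  # per-batch normalization only; no moving statistics are kept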

Utils

print all model variables

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)  # all the global variables
slim.get_model_variables()  # or tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
# variables defined by slim (tf.contrib.framework.model_variable),
# excluding gradient variables
tf.trainable_variables()  # excluding gradient variables and batch_norm moving_mean/moving_variance
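
For example, a small sketch that prints each model variable's name and shape:

for var in slim.get_model_variables():
    print(var.op.name, var.get_shape().as_list())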

print regularization losses (weight decay) and other losses

slim.losses.get_regularization_losses()  # weight-decay terms
slim.losses.get_losses()  # losses except weight decay
slim.losses.get_total_loss(add_regularization_losses=False)  # pass True to also include weight decay
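
A hedged sketch of how a custom loss joins these collections (the constant stands in for any scalar loss tensor):

my_loss = tf.constant(0.1, name='custom_loss')  # stand-in for any scalar loss
slim.losses.add_loss(my_loss)  # register it so get_losses()/get_total_loss() see it

total_loss = slim.losses.get_total_loss(add_regularization_losses=True)
# total_loss = sum of registered losses + all weight-decay terms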